
Conversation

@bnellnm (Collaborator) commented Jun 4, 2025

Add activation chunking logic to FusedMoEModularKernel. This can help with memory usage and with kernels that can't handle a large M, e.g. the triton fused_experts kernel.

This version of the PR only supports chunking for the triton kernels. I'm going to try to add it for DeepGemm as a follow-up.
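
For context, here is a minimal, self-contained sketch of the chunking idea (illustrative names only, not the exact vLLM code):

```python
import torch

def expert_fn(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fused-experts kernel; any per-token op works here.
    return torch.nn.functional.silu(x)

def chunked_apply(hidden_states: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Process the activations in slices of at most `chunk_size` tokens so
    # the underlying kernel never sees an M larger than it can handle.
    num_tokens = hidden_states.shape[0]
    output = torch.empty_like(hidden_states)
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        output[start:end] = expert_fn(hidden_states[start:end])
    return output

x = torch.randn(10, 8)
assert torch.allclose(chunked_apply(x, chunk_size=4), expert_fn(x))
```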

cc @simon-mo , @varun-sundar-rabindranath

github-actions bot commented Jun 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

gemini-code-assist bot (Contributor) left a comment

Hello @bnellnm, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team, Gemini here with a summary of this pull request. This PR introduces chunking logic into the modular Triton kernel implementation for Mixture-of-Experts (MoE) layers. The primary goal is to address potential issues (like the one referenced in #5938) that might arise when processing a large number of tokens simultaneously by breaking the computation down into smaller, manageable chunks. The kernel now processes tokens in segments determined by the VLLM_FUSED_MOE_CHUNK_SIZE environment variable.

Highlights

  • Modular Triton MoE Chunking: Implemented chunking logic within the modular_triton_fused_moe kernel to handle large inputs more effectively, processing tokens in chunks.
  • Configurable Chunk Size: The size of the processing chunks is controlled by the VLLM_FUSED_MOE_CHUNK_SIZE environment variable (a sketch of how this enters the kernel follows this list).
  • Dynamic Cache Sizing: Intermediate caches and kernel configurations are adjusted dynamically based on the size of the current chunk being processed, particularly for the final, potentially smaller chunk.
  • Added Test Coverage: A new test case has been added to verify the correctness of the modular_triton_fused_moe kernel with the new chunking logic by comparing its output against existing reference implementations.
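
To illustrate the second highlight, a hedged sketch of how an environment-variable chunk size bounds M. The fallback default below is invented for the example; vLLM reads the real value through envs.VLLM_FUSED_MOE_CHUNK_SIZE:

```python
import os

# Invented default for illustration; vLLM uses envs.VLLM_FUSED_MOE_CHUNK_SIZE.
CHUNK_SIZE = int(os.environ.get("VLLM_FUSED_MOE_CHUNK_SIZE", "8192"))

def bounded_m(num_tokens: int) -> int:
    # Kernel configs and intermediate caches are sized for at most
    # CHUNK_SIZE tokens at a time.
    return min(num_tokens, CHUNK_SIZE)
```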

Changelog

  • tests/kernels/moe/test_moe.py
    • Imported modular_triton_fused_moe (lines 18-19).
    • Added a new test block for modular_triton_fused_moe, initializing the kernel and running it with test data (lines 107-121).
    • Added an assertion to compare the output of the modular Triton kernel (m_triton_output) with the reference torch_output (lines 124-127).
  • vllm/model_executor/layers/fused_moe/fused_moe.py
    • Added comments explaining the purpose of chunking and referencing issue #5938 (lines 1602-1603).
    • Introduced CHUNK_SIZE from envs.VLLM_FUSED_MOE_CHUNK_SIZE and calculated M = min(num_tokens, CHUNK_SIZE) (lines 1604-1605).
    • Modified the call to try_get_optimal_moe_config to use functools.partial and then call it with M (lines 1612-1622).
    • Resized intermediate caches (intermediate_cache1, intermediate_cache2, intermediate_cache3) based on M (lines 1638-1641).
    • Introduced a for loop to iterate over token chunks (line 1643).
    • Sliced hidden_states and topk_ids for the current chunk within the loop (lines 1647-1648).
    • Added logic to adjust intermediate cache sizes and recalculate the config for the last chunk if it's smaller than CHUNK_SIZE (lines 1654-1663); see the sketch after this changelog.
    • Moved the invoke_fused_moe_kernel calls (for w1 and w2) inside the chunking loop to process data chunk by chunk (lines 1669-1688 and 1699-1718).
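
A self-contained sketch of the last-chunk handling described above, with illustrative shapes and names (assume the real code narrows intermediate_cache1 through intermediate_cache3 and re-derives the kernel config):

```python
import torch

CHUNK_SIZE = 4                    # stands in for VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens, hidden = 10, 8
hidden_states = torch.randn(num_tokens, hidden)

# Intermediate caches are allocated once for a full-size chunk...
intermediate_cache = torch.empty(CHUNK_SIZE, hidden)

for chunk_start in range(0, num_tokens, CHUNK_SIZE):
    chunk_end = min(chunk_start + CHUNK_SIZE, num_tokens)
    curr_m = chunk_end - chunk_start
    # ...and narrowed when the final chunk is smaller than CHUNK_SIZE.
    # In the real kernel this is also where the config would be
    # recomputed for the smaller M.
    curr_cache = intermediate_cache[:curr_m]
    curr_cache.copy_(hidden_states[chunk_start:chunk_end])
```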
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.


Tokens flow like streams,
Chunked kernels chase their dreams,
Faster inference.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist bot (Contributor) left a comment

Code Review

The pull request introduces chunking logic to the modular Triton MoE kernel (TritonExperts.apply) to address a known issue. The intention to make the kernel more modular is good. However, the current implementation of the chunking logic within TritonExperts.apply has a critical correctness issue, as it does not accumulate the results from the different chunks. Additionally, there is a high severity issue related to the initial sizing of intermediate caches.

Summary of Findings

  • Missing accumulation in chunking logic: The chunking logic added to TritonExperts.apply does not accumulate the results from each chunk, leading to incorrect outputs for large inputs (see the sketch after this list).
  • Incorrect intermediate cache sizing: Intermediate caches are sized based on the first chunk size (or total size if smaller than chunk size), which is incorrect when the total input size is larger than the chunk size.
  • Environment variable for CHUNK_SIZE: The CHUNK_SIZE is controlled by an environment variable, which could impact configurability and transparency.
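
To make the first finding concrete, a minimal sketch of the overwrite-vs-slice-write pattern, using a stand-in kernel rather than the PR's actual code:

```python
import torch

M, CHUNK_SIZE = 10, 4
x = torch.randn(M, 8)
kernel = torch.nn.functional.gelu   # stand-in for the fused kernel

# Buggy pattern: each iteration overwrites the same buffer, so only the
# final chunk's result survives.
buf = None
for start in range(0, M, CHUNK_SIZE):
    end = min(start + CHUNK_SIZE, M)
    buf = kernel(x[start:end])      # earlier chunks are lost

# Fixed pattern: each chunk writes into its own slice of the output.
output = torch.empty_like(x)
for start in range(0, M, CHUNK_SIZE):
    end = min(start + CHUNK_SIZE, M)
    output[start:end] = kernel(x[start:end])

assert torch.allclose(output, kernel(x))
```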

Merge Readiness

The pull request introduces a critical correctness issue by not accumulating results from chunks in the modular kernel. This must be addressed before merging. The high severity issue regarding cache sizing should also be fixed. The medium severity issue is a suggestion for improvement. I am unable to approve this pull request; please have other reviewers review and approve this code after the necessary changes are made.

gemini-code-assist bot (Contributor) commented

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!


bnellnm marked this pull request as ready for review June 5, 2025 14:46
bnellnm changed the title from "[Kernels] Add chunking logic to modular triton kernel" to "[Kernels] Add activation chunking logic to FusedMoEModularKernel" Jun 5, 2025
mergify bot commented Jun 7, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Jun 7, 2025
@varun-sundar-rabindranath (Contributor) commented

Hey @bnellnm, I left some comments, mostly nits except for #19168 (comment) - I think it could be a bug, but maybe I'm missing something.

mergify bot commented Jun 9, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @bnellnm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Jun 9, 2025
bnellnm added 2 commits June 9, 2025 19:42
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
bnellnm added 4 commits June 9, 2025 19:44
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
bnellnm force-pushed the triton-moe-chunk branch from 8a3abe5 to 32704dc on June 9, 2025 19:45
mergify bot removed the needs-rebase label Jun 9, 2025
return c3[c_map].view(M, topk, K)
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
@varun-sundar-rabindranath (Contributor) commented Jun 9, 2025

It'd be nice if we could detect this case - we could do it later as an optimization 👍

@bnellnm (Collaborator, Author) replied

Yeah. I'm not sure how to do that though. At least this should be no worse than it was before.
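
One possible direction, offered here as an untested assumption rather than the PR's approach: a conservative storage-identity check could tell the two tensors apart, at the cost of occasional false positives for disjoint views of the same storage.

```python
import torch

def may_alias(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Two tensors can only overlap if they share the same underlying
    # storage. This can report True for disjoint views of one storage,
    # but a False result is always safe.
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

base = torch.empty(8)
assert may_alias(base[:4], base[4:])        # conservative: same storage
assert not may_alias(base, torch.empty(8))  # distinct storages never alias
```

The copy above could then be skipped, or replaced with a direct write, whenever may_alias(output, c3) is False.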

Signed-off-by: Bill Nell <bnell@redhat.com>
@tlrmchlsmth (Member) left a comment

LGTM

tlrmchlsmth added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jun 9, 2025
tlrmchlsmth enabled auto-merge (squash) June 9, 2025 22:32
bnellnm added 2 commits June 10, 2025 21:00
…date test utils

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
auto-merge was automatically disabled June 10, 2025 21:13

Head branch was pushed to by a user without write access

bnellnm added 2 commits June 11, 2025 01:41
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
tlrmchlsmth merged commit 29fa5ca into vllm-project:main Jun 11, 2025
72 checks passed
Comment on lines +469 to +477
for chunk in range(num_chunks):
    begin_chunk_idx = chunk * CHUNK_SIZE
    end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M)
    begin_out_idx = chunk * OUT_CHUNK_SIZE
    end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
    curr_a1q = a1q[begin_chunk_idx:end_chunk_idx]
    curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx,
                                   end_chunk_idx)
    curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
                                  end_chunk_idx)
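
For readers wondering what _chunk_scales does, a plausible reconstruction based on how the call sites use it, not a copy of the real helper: per-tensor scales pass through unchanged while per-token scales are sliced to the chunk.

```python
import torch
from typing import Optional

def _chunk_scales_sketch(scales: Optional[torch.Tensor], start: int,
                         end: int) -> Optional[torch.Tensor]:
    if scales is None:
        return None
    if scales.numel() == 1:
        # Per-tensor scale: one value applies to every chunk.
        return scales
    # Per-token scales: take the rows for the current chunk.
    return scales[start:end]
```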
A Collaborator commented

@bnellnm @tlrmchlsmth I'm not completely sure that vLLM-compile works with this. I don't have a repro yet, but if M is the batch size, vLLM-compile wants to treat it as dynamic, and range(num_chunks) can cause a specialization on it.
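
A minimal illustration of the concern, independent of vLLM (assuming torch.compile with dynamic shapes; the recompile behavior is the point, not the exact kernel):

```python
import torch

CHUNK_SIZE = 4

def chunked(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    # range() runs at the Python level, so the trip count (derived from
    # x.shape[0]) must become a concrete int at trace time, which can
    # force specialization on the batch size despite dynamic=True.
    num_chunks = (x.shape[0] + CHUNK_SIZE - 1) // CHUNK_SIZE
    for i in range(num_chunks):
        s = i * CHUNK_SIZE
        e = min(s + CHUNK_SIZE, x.shape[0])
        out[s:e] = torch.relu(x[s:e])
    return out

compiled = torch.compile(chunked, dynamic=True)
compiled(torch.randn(8, 16))    # each distinct batch size may trigger
compiled(torch.randn(12, 16))   # a recompile because of the loop bound
```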

A Collaborator replied

Is disabling the chunking equivalent to the state before this PR?

A Collaborator replied

Here's the repro: #19631
